Skip to content

Commit dd506e8

Browse files
authored
Add instance pool id to WorkspaceConfig (#1087)
1 parent 567638c commit dd506e8

File tree

5 files changed

+26
-18
lines changed

5 files changed

+26
-18
lines changed

src/databricks/labs/ucx/install.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,9 @@ def _configure_new_installation(self) -> WorkspaceConfig:
201201
log_level = self._prompts.question("Log level", default="INFO").upper()
202202
num_threads = int(self._prompts.question("Number of threads", default="8", valid_number=True))
203203

204-
policy_id, instance_profile, spark_conf_dict = self._policy_installer.create(inventory_database)
204+
policy_id, instance_profile, spark_conf_dict, instance_pool_id = self._policy_installer.create(
205+
inventory_database
206+
)
205207

206208
# Check if terraform is being used
207209
is_terraform_used = self._prompts.confirm("Do you use Terraform to deploy your infrastructure?")
@@ -220,6 +222,7 @@ def _configure_new_installation(self) -> WorkspaceConfig:
220222
instance_profile=instance_profile,
221223
spark_conf=spark_conf_dict,
222224
policy_id=policy_id,
225+
instance_pool_id=instance_pool_id,
223226
is_terraform_used=is_terraform_used,
224227
include_databases=self._select_databases(),
225228
)

src/databricks/labs/ucx/installer/policy.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(self, installation: Installation, ws: WorkspaceClient, prompts: Pro
2222
def _policy_config(value: str):
2323
return {"type": "fixed", "value": value}
2424

25-
def create(self, inventory_database: str) -> tuple[str, str, dict]:
25+
def create(self, inventory_database: str) -> tuple[str, str, dict, str | None]:
2626
instance_profile = ""
2727
spark_conf_dict = {}
2828
# get instance pool id to be put into the cluster policy
@@ -52,7 +52,7 @@ def create(self, inventory_database: str) -> tuple[str, str, dict]:
5252
logger.info(f"Cluster policy {policy_name} already present, reusing the same.")
5353
policy_id = policy.policy_id
5454
assert policy_id is not None
55-
return policy_id, instance_profile, spark_conf_dict
55+
return policy_id, instance_profile, spark_conf_dict, instance_pool_id
5656
logger.info("Creating UCX cluster policy.")
5757
policy_id = self._ws.cluster_policies.create(
5858
name=policy_name,
@@ -64,6 +64,7 @@ def create(self, inventory_database: str) -> tuple[str, str, dict]:
6464
policy_id,
6565
instance_profile,
6666
spark_conf_dict,
67+
instance_pool_id,
6768
)
6869

6970
def _get_instance_pool_id(self) -> str | None:
@@ -72,7 +73,7 @@ def _get_instance_pool_id(self) -> str | None:
7273
"Instance pool id to be set in cluster policy for all workflow clusters", default="None"
7374
)
7475
except OSError:
75-
# when unit test v0.15.0_added_cluster_policy.py MockPromots cannot be injected to ClusterPolicyInstaller
76+
# when unit test v0.15.0_added_cluster_policy.py MockPrompts cannot be injected to ClusterPolicyInstaller
7677
# return None to pass the test
7778
return None
7879
if instance_pool_id.lower() == "none":

src/databricks/labs/ucx/upgrades/v0.15.0_added_cluster_policy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
def upgrade(installation: Installation, ws: WorkspaceClient):
1717
config = installation.load(WorkspaceConfig)
1818
policy_installer = ClusterPolicyInstaller(installation, ws, Prompts())
19-
config.policy_id, _, _ = policy_installer.create(config.inventory_database)
19+
config.policy_id, _, _, _ = policy_installer.create(config.inventory_database)
2020
installation.save(config)
2121
states = InstallState.from_installation(installation)
2222
assert config.policy_id is not None

tests/integration/test_installation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import logging
44
import os.path
5+
import sys
56
from collections.abc import Callable
67
from dataclasses import replace
78
from datetime import timedelta
@@ -474,9 +475,9 @@ def test_check_inventory_database_exists(ws, new_installation):
474475
def test_table_migration_job( # pylint: disable=too-many-locals
475476
ws, new_installation, make_catalog, make_schema, make_table, env_or_skip, make_random, make_dbfs_data_copy
476477
):
477-
# skip this test if not in nightly test job: TEST_NIGHTLY is missing or is not set to "true"
478-
if env_or_skip("TEST_NIGHTLY").lower() != "true":
479-
pytest.skip("TEST_NIGHTLY is not true")
478+
# skip this test if not in nightly test job or debug mode
479+
if os.path.basename(sys.argv[0]) not in {"_jb_pytest_runner.py", "testlauncher.py"}:
480+
env_or_skip("TEST_NIGHTLY")
480481
# create external and managed tables to be migrated
481482
src_schema = make_schema(catalog_name="hive_metastore")
482483
src_managed_table = make_table(schema_name=src_schema.name)

tests/unit/installer/test_policy.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_cluster_policy_definition_present_reuse():
4444
ws, prompts = common()
4545

4646
policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
47-
policy_id, _, _ = policy_installer.create('ucx')
47+
policy_id, _, _, _ = policy_installer.create('ucx')
4848
assert policy_id is not None
4949
ws.cluster_policies.create.assert_not_called()
5050

@@ -73,7 +73,7 @@ def test_cluster_policy_definition_azure_hms():
7373
)
7474
]
7575
policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
76-
policy_id, _, _ = policy_installer.create('ucx')
76+
policy_id, _, _, _ = policy_installer.create('ucx')
7777
policy_definition_actual = {
7878
"spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},
7979
"node_type_id": {"type": "fixed", "value": "Standard_F4s"},
@@ -114,7 +114,7 @@ def test_cluster_policy_definition_aws_glue():
114114
]
115115

116116
policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
117-
policy_id, instance_profile, _ = policy_installer.create('ucx')
117+
policy_id, instance_profile, _, _ = policy_installer.create('ucx')
118118
policy_definition_actual = {
119119
"spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},
120120
"node_type_id": {"type": "fixed", "value": "Standard_F4s"},
@@ -155,7 +155,7 @@ def test_cluster_policy_definition_gcp():
155155
]
156156

157157
policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
158-
policy_id, instance_profile, _ = policy_installer.create('ucx')
158+
policy_id, instance_profile, _, _ = policy_installer.create('ucx')
159159
policy_definition_actual = {
160160
"spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},
161161
"node_type_id": {"type": "fixed", "value": "Standard_F4s"},
@@ -257,7 +257,7 @@ def test_cluster_policy_definition_azure_hms_warehouse():
257257
}
258258
)
259259
policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
260-
policy_id, _, _ = policy_installer.create('ucx')
260+
policy_id, _, _, _ = policy_installer.create('ucx')
261261
policy_definition_actual = {
262262
"spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},
263263
"node_type_id": {"type": "fixed", "value": "Standard_F4s"},
@@ -309,7 +309,7 @@ def test_cluster_policy_definition_aws_glue_warehouse():
309309
}
310310
)
311311
policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
312-
policy_id, instance_profile, _ = policy_installer.create('ucx')
312+
policy_id, instance_profile, _, _ = policy_installer.create('ucx')
313313
policy_definition_actual = {
314314
"spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},
315315
"node_type_id": {"type": "fixed", "value": "Standard_F4s"},
@@ -364,7 +364,7 @@ def test_cluster_policy_definition_gcp_hms_warehouse():
364364
}
365365
)
366366
policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
367-
policy_id, _, _ = policy_installer.create('ucx')
367+
policy_id, _, _, _ = policy_installer.create('ucx')
368368
policy_definition_actual = {
369369
"spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},
370370
"node_type_id": {"type": "fixed", "value": "Standard_F4s"},
@@ -405,7 +405,7 @@ def test_cluster_policy_definition_empty_config():
405405
]
406406

407407
policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
408-
policy_id, _, _ = policy_installer.create('ucx')
408+
policy_id, _, _, _ = policy_installer.create('ucx')
409409
policy_definition_actual = {
410410
"spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},
411411
"node_type_id": {"type": "fixed", "value": "Standard_F4s"},
@@ -431,7 +431,9 @@ def test_cluster_policy_instance_pool():
431431
ws.config.is_gcp = False
432432

433433
policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
434-
policy_installer.create('ucx')
434+
_, _, _, instance_pool_id = policy_installer.create('ucx')
435+
436+
assert instance_pool_id == "instance_pool_1"
435437

436438
policy_expected = {
437439
"spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},
@@ -452,7 +454,8 @@ def test_cluster_policy_instance_pool():
452454
"node_type_id": {"type": "fixed", "value": "Standard_F4s"},
453455
"aws_attributes.availability": {"type": "fixed", "value": "ON_DEMAND"},
454456
}
455-
policy_installer.create('ucx')
457+
_, _, _, instance_pool_id = policy_installer.create('ucx')
458+
assert instance_pool_id is None
456459
ws.cluster_policies.create.assert_called_with(
457460
name="Unity Catalog Migration (ucx) ([email protected])",
458461
definition=json.dumps(policy_expected),

0 commit comments

Comments (0)