diff --git a/src/xpk/commands/cluster.py b/src/xpk/commands/cluster.py
index e1895f8b4..4ac769548 100644
--- a/src/xpk/commands/cluster.py
+++ b/src/xpk/commands/cluster.py
@@ -81,6 +81,7 @@ from ..utils.templates import get_templates_absolute_path
 
 import shutil
 import os
+import time
 
 CLUSTER_PREHEAT_JINJA_FILE = 'cluster_preheat.yaml.j2'
 
@@ -407,6 +408,10 @@ def cluster_create(args) -> None:
       # pylint: disable=line-too-long
       f' https://console.cloud.google.com/kubernetes/clusters/details/{get_cluster_location(args.project, args.cluster, args.zone)}/{args.cluster}/details?project={args.project}'
   )
+  if args.managed_mldiagnostics:
+    return_code = install_mldiagnostics_prerequisites()
+    if return_code != 0:
+      xpk_exit(return_code)
   xpk_exit(0)
 
 
@@ -1319,3 +1324,280 @@ def prepare_gpus(system: SystemCharacteristics):
     err_code = disable_mglru_on_cluster()
     if err_code > 0:
       xpk_exit(err_code)
+
+
+def install_cert_manager(version: str = 'v1.13.0') -> int:
+  """Applies the cert-manager release manifest to the cluster.
+
+  Args:
+    version: cert-manager release tag to install, e.g. 'v1.13.0'.
+
+  Returns:
+    0 if successful, otherwise the non-zero return code of the command.
+  """
+  command = (
+      'kubectl apply -f'
+      ' https://github.com/cert-manager/cert-manager/releases/download/'
+      f'{version}/cert-manager.yaml'
+  )
+
+  return_code = run_command_with_updates(
+      command, f'Applying cert-manager {version} manifest...'
+  )
+
+  if return_code != 0:
+    xpk_print(f'Applying cert-manager returned with ERROR {return_code}.\n')
+
+  return return_code
+
+
+def download_mldiagnostics_yaml(package_name: str, version: str) -> int:
+  """Downloads an mldiagnostics YAML manifest from Artifact Registry.
+
+  The artifact is written to the current working directory. A failure whose
+  output reports that the file already exists is treated as success.
+
+  Args:
+    package_name: name of the generic artifact package to download.
+    version: version of the package to download.
+
+  Returns:
+    0 if the artifact was downloaded or already exists locally. On any other
+    failure xpk exits with the return code of the gcloud command.
+  """
+  command = (
+      'gcloud artifacts generic download'
+      ' --repository=mldiagnostics-webhook-and-operator-yaml --location=us'
+      f' --package={package_name} --version={version} --destination=./'
+      ' --project=ai-on-gke'
+  )
+
+  return_code, return_output = run_command_for_value(
+      command,
+      f'Starting gcloud artifacts download for {package_name} {version}...',
+  )
+
+  if return_code != 0:
+    if 'already exists' in return_output:
+      xpk_print(
+          f'Artifact file for {package_name} {version} already exists locally.'
+          ' Skipping download.'
+      )
+      return 0
+    xpk_print(f'gcloud download returned with ERROR {return_code}.\n')
+    xpk_exit(return_code)
+
+  xpk_print('Artifact download completed successfully.')
+  return return_code
+
+
+def create_mldiagnostics_namespace() -> int:
+  """Creates the 'gke-mldiagnostics' namespace in the cluster.
+
+  Returns:
+    0 if the namespace was created or already exists. On any other failure
+    xpk exits with the return code of the kubectl command.
+  """
+  command = 'kubectl create namespace gke-mldiagnostics'
+
+  return_code, return_output = run_command_for_value(
+      command, 'Starting kubectl create namespace...'
+  )
+
+  if return_code != 0:
+    if 'already exists' in return_output:
+      xpk_print('Namespace already exists in the cluster. Skipping creation.')
+      return 0
+    xpk_print(f'Namespace creation returned with ERROR {return_code}.\n')
+    xpk_exit(return_code)
+
+  xpk_print('gke-mldiagnostics Namespace created or already exists.')
+  return return_code
+
+
+def install_mldiagnostics_yaml(artifact_filename: str) -> int:
+  """Applies a downloaded mldiagnostics manifest, then deletes the file.
+
+  Args:
+    artifact_filename: path of the local YAML manifest to apply.
+
+  Returns:
+    0 if the manifest was applied; xpk exits with the kubectl return code
+    otherwise. Failing to delete the local file is not treated as an error.
+  """
+  command = f'kubectl apply -f {artifact_filename} -n gke-mldiagnostics'
+
+  return_code = run_command_with_updates(
+      command,
+      f'Starting kubectl apply -f {artifact_filename} -n gke-mldiagnostics...',
+  )
+
+  if return_code != 0:
+    xpk_print(f'kubectl apply returned with ERROR {return_code}.\n')
+    xpk_exit(return_code)
+
+  xpk_print(f'{artifact_filename} applied successfully.')
+
+  # Cleanup is best-effort: the manifest is already applied at this point, so
+  # a filesystem error must not abort the cluster setup.
+  try:
+    os.remove(artifact_filename)
+  except FileNotFoundError:
+    xpk_print(
+        f'File {artifact_filename} does not exist locally. Skipping deletion'
+        ' (Cleanup assumed).'
+    )
+  except OSError as error:
+    xpk_print(f'Failed to delete file {artifact_filename}: {error}')
+  else:
+    xpk_print(f'Successfully deleted local file: {artifact_filename}')
+
+  return return_code
+
+
+def label_default_namespace_mldiagnostics() -> int:
+  """Labels the 'default' namespace with 'managed-mldiagnostics-gke=true'.
+
+  Returns:
+    0 if successful; xpk exits with the kubectl return code otherwise.
+  """
+  # --overwrite keeps the step idempotent when the label already exists with
+  # a different value; without it kubectl rejects the update.
+  command = (
+      'kubectl label namespace default managed-mldiagnostics-gke=true'
+      ' --overwrite'
+  )
+
+  return_code = run_command_with_updates(
+      command,
+      'Starting kubectl label namespace default with'
+      ' managed-mldiagnostics-gke=true...',
+  )
+
+  if return_code != 0:
+    xpk_print(f'Namespace labeling returned with ERROR {return_code}.\n')
+    xpk_exit(return_code)
+
+  xpk_print('default Namespace successfully labeled.')
+  return return_code
+
+
+def install_mldiagnostics_prerequisites() -> int:
+  """Installs cert-manager and the mldiagnostics webhook and operator.
+
+  Waits for the Kueue controller to be ready, installs cert-manager, waits
+  for its webhook, then applies the injection webhook and the connection
+  operator manifests and labels the 'default' namespace.
+
+  Returns:
+    0 if every step succeeded, 1 if a deployment did not become ready,
+    otherwise the non-zero return code of the failing step.
+  """
+  deployment_name = 'kueue-controller-manager'
+  namespace_name = 'kueue-system'
+  cert_webhook_deployment_name = 'cert-manager-webhook'
+  cert_webhook_namespace_name = 'cert-manager'
+
+  if not wait_for_deployment_ready(deployment_name, namespace_name):
+    xpk_print(
+        f'Application {deployment_name} failed to become ready within the'
+        ' timeout.'
+    )
+    return 1
+  # NOTE(review): fixed grace period after the rollout is reported ready;
+  # presumably lets the deployment settle before it is relied on.
+  time.sleep(30)
+
+  return_code = install_cert_manager()
+  if return_code != 0:
+    return return_code
+
+  if not wait_for_deployment_ready(
+      cert_webhook_deployment_name, cert_webhook_namespace_name
+  ):
+    xpk_print('The cert-manager-webhook installation failed.')
+    return 1
+  time.sleep(30)
+
+  # --- Install injection webhook ---
+  webhook_package = 'mldiagnostics-injection-webhook'
+  webhook_version = 'v0.5.0'
+  webhook_filename = f'{webhook_package}-{webhook_version}.yaml'
+
+  return_code = download_mldiagnostics_yaml(
+      package_name=webhook_package, version=webhook_version
+  )
+  if return_code != 0:
+    return return_code
+
+  return_code = create_mldiagnostics_namespace()
+  if return_code != 0:
+    return return_code
+
+  return_code = install_mldiagnostics_yaml(artifact_filename=webhook_filename)
+  if return_code != 0:
+    return return_code
+
+  return_code = label_default_namespace_mldiagnostics()
+  if return_code != 0:
+    return return_code
+
+  # --- Install connection operator ---
+  operator_package = 'mldiagnostics-connection-operator'
+  operator_version = 'v0.5.0'
+  operator_filename = f'{operator_package}-{operator_version}.yaml'
+
+  return_code = download_mldiagnostics_yaml(
+      package_name=operator_package, version=operator_version
+  )
+  if return_code != 0:
+    return return_code
+
+  return_code = install_mldiagnostics_yaml(artifact_filename=operator_filename)
+  if return_code != 0:
+    return return_code
+
+  xpk_print(
+      'All mldiagnostics installation and setup steps have been'
+      ' successfully completed!'
+  )
+  return return_code
+
+
+def wait_for_deployment_ready(
+    deployment_name: str, namespace: str, timeout_seconds: int = 300
+) -> bool:
+  """Blocks until a Deployment has rolled out or the timeout expires.
+
+  Runs `kubectl rollout status` once, bounded by the given timeout.
+
+  Args:
+    deployment_name: name of the Kubernetes Deployment, e.g.
+      'kueue-controller-manager'.
+    namespace: namespace of the Deployment, e.g. 'kueue-system'.
+    timeout_seconds: maximum time to wait, in seconds.
+
+  Returns:
+    True if the Deployment rolled out, False on timeout or any other error.
+  """
+  command = (
+      f'kubectl rollout status deployment/{deployment_name} -n {namespace}'
+      f' --timeout={timeout_seconds}s'
+  )
+
+  xpk_print(
+      f'Waiting for deployment {deployment_name} in namespace {namespace} to'
+      ' successfully roll out...'
+  )
+
+  return_code, return_output = run_command_for_value(
+      command, f'Checking status of deployment {deployment_name}...'
+  )
+
+  if return_code != 0:
+    xpk_print(f'\nError: Deployment {deployment_name} failed to roll out.')
+    xpk_print(f'kubectl output: {return_output}')
+    return False
+
+  xpk_print(f'Success: Deployment {deployment_name} successfully rolled out.')
+  return True
diff --git a/src/xpk/commands/cluster_test.py b/src/xpk/commands/cluster_test.py
index 153ace154..105219521 100644
--- a/src/xpk/commands/cluster_test.py
+++ b/src/xpk/commands/cluster_test.py
@@ -20,7 +20,7 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 
-from xpk.commands.cluster import _install_kueue, _validate_cluster_create_args, run_gke_cluster_create_command
+from xpk.commands.cluster import _install_kueue, _validate_cluster_create_args, run_gke_cluster_create_command, install_mldiagnostics_prerequisites
 from xpk.core.system_characteristics import SystemCharacteristics, UserFacingNameToSystemCharacteristics
 from xpk.core.testing.commands_tester import CommandsTester
 from xpk.utils.feature_flags import FeatureFlags
@@ -56,6 +56,9 @@ def mocks(mocker) -> _Mocks:
           run_command_with_updates_path=(
               'xpk.commands.cluster.run_command_with_updates'
           ),
+          run_command_for_value_path=(
+              'xpk.commands.cluster.run_command_for_value'
+          ),
       ),
   )
 
@@ -87,6 +90,7 @@ def construct_args(**kwargs: Any) -> Namespace:
       memory_limit='100Gi',
       cpu_limit=100,
       cluster_cpu_machine_type='',
+      managed_mldiagnostics=False,
   )
   args_dict.update(kwargs)
   return Namespace(**args_dict)
@@ -247,3 +251,58 @@ def test_run_gke_cluster_create_command_with_gke_version_has_no_autoupgrade_flag
   mocks.commands_tester.assert_command_run(
       'clusters create', ' --no-enable-autoupgrade'
   )
+
+
+def test_install_mldiagnostics_prerequisites_commands_executed(
+    mocks: _Mocks,
+    mocker,
+):
+  mock_sleep = mocker.patch('time.sleep', return_value=None)
+
+  mock_wait_ready = mocker.patch(
+      'xpk.commands.cluster.wait_for_deployment_ready', return_value=True
+  )
+  mock_install_cert = mocker.patch(
+      'xpk.commands.cluster.install_cert_manager', return_value=0
+  )
+  mock_download = mocker.patch(
+      'xpk.commands.cluster.download_mldiagnostics_yaml', return_value=0
+  )
+  mock_create_ns = mocker.patch(
+      'xpk.commands.cluster.create_mldiagnostics_namespace', return_value=0
+  )
+  mock_install_yaml = mocker.patch(
+      'xpk.commands.cluster.install_mldiagnostics_yaml', return_value=0
+  )
+  mock_label_ns = mocker.patch(
+      'xpk.commands.cluster.label_default_namespace_mldiagnostics',
+      return_value=0,
+  )
+
+  mocker.patch('os.path.exists', return_value=True)
+  mocker.patch('os.remove')
+
+  install_mldiagnostics_prerequisites()
+
+  mock_wait_ready.assert_any_call('kueue-controller-manager', 'kueue-system')
+
+  assert mock_sleep.call_count == 2
+  mock_sleep.assert_any_call(30)
+
+  mock_install_cert.assert_called_once()
+
+  mock_wait_ready.assert_any_call('cert-manager-webhook', 'cert-manager')
+
+  assert mock_download.call_count == 2
+  mock_download.assert_any_call(
+      package_name='mldiagnostics-injection-webhook', version='v0.5.0'
+  )
+  mock_download.assert_any_call(
+      package_name='mldiagnostics-connection-operator', version='v0.5.0'
+  )
+
+  mock_create_ns.assert_called_once()
+
+  assert mock_install_yaml.call_count == 2
+
+  mock_label_ns.assert_called_once()
diff --git a/src/xpk/parser/cluster.py b/src/xpk/parser/cluster.py
index e0d5af7bd..e05e56ac5 100644
--- a/src/xpk/parser/cluster.py
+++ b/src/xpk/parser/cluster.py
@@ -150,6 +150,18 @@ def set_cluster_create_parser(cluster_create_parser: ArgumentParser):
           ' enable cluster to accept Pathways workloads.'
       ),
   )
+
+  cluster_create_optional_arguments.add_argument(
+      '--managed-mldiagnostics',
+      action='store_true',
+      default=False,
+      help=(
+          '[Optional] Enables the installation of required ML Diagnostics'
+          ' components: cert-manager, injection-webhook, and'
+          ' connection-operator. This feature is OFF by default.'
+      ),
+  )
+
   if FeatureFlags.SUB_SLICING_ENABLED:
     cluster_create_optional_arguments.add_argument(
         '--sub-slicing',
@@ -222,6 +234,17 @@ def set_cluster_create_pathways_parser(
       ),
   )
 
+  cluster_create_pathways_required_arguments.add_argument(
+      '--managed-mldiagnostics',
+      action='store_true',
+      default=False,
+      help=(
+          '[Optional] Enables the installation of required ML Diagnostics'
+          ' components: cert-manager, injection-webhook, and'
+          ' connection-operator. This feature is OFF by default.'
+      ),
+  )
+
   ### Optional arguments specific to "cluster create-pathways"
   cluster_create_pathways_optional_arguments = (
       cluster_create_pathways_parser.add_argument_group(
diff --git a/src/xpk/parser/cluster_test.py b/src/xpk/parser/cluster_test.py
index 2b2706b4f..55bdb7ac6 100644
--- a/src/xpk/parser/cluster_test.py
+++ b/src/xpk/parser/cluster_test.py
@@ -64,3 +64,18 @@ def test_cluster_create_sub_slicing_can_be_set():
   )
 
   assert args.sub_slicing is True
+
+
+def test_cluster_create_managed_mldiagnostics():
+  parser = argparse.ArgumentParser()
+
+  set_cluster_create_parser(parser)
+  args = parser.parse_args([
+      "--cluster",
+      "test-cluster",
+      "--tpu-type",
+      "v5p-8",
+      "--managed-mldiagnostics",
+  ])
+
+  assert args.managed_mldiagnostics is True