Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 271 additions & 0 deletions src/xpk/commands/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
from ..utils.templates import get_templates_absolute_path
import shutil
import os
import time

CLUSTER_PREHEAT_JINJA_FILE = 'cluster_preheat.yaml.j2'

Expand Down Expand Up @@ -407,6 +408,8 @@ def cluster_create(args) -> None:
# pylint: disable=line-too-long
f' https://console.cloud.google.com/kubernetes/clusters/details/{get_cluster_location(args.project, args.cluster, args.zone)}/{args.cluster}/details?project={args.project}'
)
if args.managed_mldiagnostics:
install_mldiagnostics_prerequisites()
xpk_exit(0)


Expand Down Expand Up @@ -1319,3 +1322,271 @@ def prepare_gpus(system: SystemCharacteristics):
err_code = disable_mglru_on_cluster()
if err_code > 0:
xpk_exit(err_code)


def install_cert_manager(version: str = 'v1.13.0'):
"""
Apply the cert-manager manifest.

Returns:
0 if successful and 1 otherwise.
"""

command = (
'kubectl apply -f'
' https://github.com/cert-manager/cert-manager/releases/download/'
f'{version}/cert-manager.yaml'
)

return_code = run_command_with_updates(
command, f'Applying cert-manager {version} manifest...'
)

if return_code != 0:
xpk_print(f'Applying cert-manager returned with ERROR {return_code}.\n')

return return_code


def download_mldiagnostics_yaml(package_name: str, version: str):
"""
Downloads the mldiagnostics injection webhook YAML from Artifact Registry.

Returns:
0 if successful and 1 otherwise.
"""

command = (
'gcloud artifacts generic download'
' --repository=mldiagnostics-webhook-and-operator-yaml --location=us'
f' --package={package_name} --version={version} --destination=./'
' --project=ai-on-gke'
)

return_code, return_output = run_command_for_value(
command,
f'Starting gcloud artifacts download for {package_name} {version}...',
)

if return_code != 0:
if 'already exists' in return_output:
xpk_print(
f'Artifact file for {package_name} {version} already exists locally.'
' Skipping download.'
)
return 0
xpk_print(f'gcloud download returned with ERROR {return_code}.\n')
xpk_exit(return_code)

xpk_print('Artifact download completed successfully.')
return return_code


def create_mldiagnostics_namespace():
"""
Creates the 'gke-mldiagnostics' namespace.

Returns:
0 if successful and 1 otherwise.
"""

command = 'kubectl create namespace gke-mldiagnostics'

return_code, return_output = run_command_for_value(
command, 'Starting kubectl create namespace...'
)

if return_code != 0:
if 'already exists' in return_output:
xpk_print('Namespace already exists locally. Skipping creation.')
return 0
xpk_print(f'Namespace creation returned with ERROR {return_code}.\n')
xpk_exit(return_code)

xpk_print('gke-mldiagnostics Namespace created or already exists.')
return return_code


def install_mldiagnostics_yaml(artifact_filename: str):
"""
Applies the mldiagnostics injection webhook YAML manifest.

Returns:
0 if successful and 1 otherwise.
"""

command = f'kubectl apply -f {artifact_filename} -n gke-mldiagnostics'

return_code = run_command_with_updates(
command,
f'Starting kubectl apply -f {artifact_filename} -n gke-mldiagnostics...',
)

if return_code != 0:
xpk_print(f'kubectl apply returned with ERROR {return_code}.\n')
xpk_exit(return_code)

xpk_print(f'{artifact_filename} applied successfully.')

if os.path.exists(artifact_filename):
try:
os.remove(artifact_filename)
xpk_print(f'Successfully deleted local file: {artifact_filename}')

except PermissionError:
xpk_print(
f'Failed to delete file {artifact_filename} due to Permission Error.'
)

else:
xpk_print(
f'File {artifact_filename} does not exist locally. Skipping deletion'
' (Cleanup assumed).'
)

return return_code


def label_default_namespace_mldiagnostics():
"""
Labels the 'default' namespace with 'managed-mldiagnostics-gke=true'.

Returns:
0 if successful and 1 otherwise.
"""

command = 'kubectl label namespace default managed-mldiagnostics-gke=true'

return_code = run_command_with_updates(
command,
'Starting kubectl label namespace default with'
' managed-mldiagnostics-gke=true...',
)

if return_code != 0:
xpk_print(f'Namespace labeling returned with ERROR {return_code}.\n')
xpk_exit(return_code)

xpk_print('default Namespace successfully labeled.')
return return_code


def install_mldiagnostics_prerequisites():
"""
Mldiagnostics installation requirements.

Returns:
0 if successful and 1 otherwise.
"""
deployment_name = 'kueue-controller-manager'
namespace_name = 'kueue-system'
cert_webhook_deployment_name = 'cert-manager-webhook'
cert_webhook_namespace_name = 'cert-manager'
# is_running = wait_for_cluster_running(args)
is_running = wait_for_deployment_ready(deployment_name, namespace_name)
time.sleep(30)
if is_running:
return_code = install_cert_manager()
if return_code != 0:
return return_code

cert_webhook_ready = wait_for_deployment_ready(
cert_webhook_deployment_name, cert_webhook_namespace_name
)
time.sleep(30)
if cert_webhook_ready:

webhook_package = 'mldiagnostics-injection-webhook'
webhook_version = 'v0.5.0'
webhook_filename = f'{webhook_package}-{webhook_version}.yaml'

return_code = download_mldiagnostics_yaml(
package_name=webhook_package, version=webhook_version
)
if return_code != 0:
return return_code

return_code = create_mldiagnostics_namespace()
if return_code != 0:
return return_code

return_code = install_mldiagnostics_yaml(
artifact_filename=webhook_filename
)
if return_code != 0:
return return_code

return_code = label_default_namespace_mldiagnostics()
if return_code != 0:
return return_code

# --- Install Operator ---
operator_package = 'mldiagnostics-connection-operator'
operator_version = 'v0.5.0'
operator_filename = f'{operator_package}-{operator_version}.yaml'

return_code = download_mldiagnostics_yaml(
package_name=operator_package, version=operator_version
)
if return_code != 0:
return return_code

return_code = install_mldiagnostics_yaml(
artifact_filename=operator_filename
)
if return_code != 0:
return return_code

xpk_print(
'All mldiagnostics installation and setup steps have been'
' successfully completed!'
)
return return_code
else:
xpk_print('The cert-manager-webhook installation failed.')
xpk_exit(1)
else:
xpk_print(
f'Application {deployment_name} failed to become ready within the'
' timeout.'
)
xpk_exit(1)


def wait_for_deployment_ready(
deployment_name: str, namespace: str, timeout_seconds: int = 300
) -> bool:
"""
Polls the Kubernetes Deployment status using kubectl rollout status
until it successfully rolls out (all replicas are ready) or times out.

Args:
deployment_name: The name of the Kubernetes Deployment (e.g., 'kueue-controller-manager').
namespace: The namespace where the Deployment is located (e.g., 'kueue-system').
timeout_seconds: Timeout duration in seconds (default is 300s / 5 minutes).

Returns:
bool: True if the Deployment successfully rolled out, False otherwise (timeout or error).
"""

command = (
f'kubectl rollout status deployment/{deployment_name} -n {namespace}'
f' --timeout={timeout_seconds}s'
)

print(
f'Waiting for deployment {deployment_name} in namespace {namespace} to'
' successfully roll out...'
)

return_code, return_output = run_command_for_value(
command, f'Checking status of deployment {deployment_name}...'
)

if return_code != 0:
xpk_print(f'\nError: Deployment {deployment_name} failed to roll out.')
xpk_print(f'kubectl output: {return_output}')
return False

xpk_print(f'Success: Deployment {deployment_name} successfully rolled out.')
return True
61 changes: 60 additions & 1 deletion src/xpk/commands/cluster_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from unittest.mock import MagicMock, patch
import pytest

from xpk.commands.cluster import _install_kueue, _validate_cluster_create_args, run_gke_cluster_create_command
from xpk.commands.cluster import _install_kueue, _validate_cluster_create_args, run_gke_cluster_create_command, install_mldiagnostics_prerequisites
from xpk.core.system_characteristics import SystemCharacteristics, UserFacingNameToSystemCharacteristics
from xpk.core.testing.commands_tester import CommandsTester
from xpk.utils.feature_flags import FeatureFlags
Expand Down Expand Up @@ -56,6 +56,9 @@ def mocks(mocker) -> _Mocks:
run_command_with_updates_path=(
'xpk.commands.cluster.run_command_with_updates'
),
run_command_for_value_path=(
'xpk.commands.cluster.run_command_for_value'
),
),
)

Expand Down Expand Up @@ -87,6 +90,7 @@ def construct_args(**kwargs: Any) -> Namespace:
memory_limit='100Gi',
cpu_limit=100,
cluster_cpu_machine_type='',
managed_mldiagnostics=False,
)
args_dict.update(kwargs)
return Namespace(**args_dict)
Expand Down Expand Up @@ -247,3 +251,58 @@ def test_run_gke_cluster_create_command_with_gke_version_has_no_autoupgrade_flag
mocks.commands_tester.assert_command_run(
'clusters create', ' --no-enable-autoupgrade'
)


def test_install_mldiagnostics_prerequisites_commands_executed(
mocks: _Mocks,
mocker,
):
mock_sleep = mocker.patch('time.sleep', return_value=None)

mock_wait_ready = mocker.patch(
'xpk.commands.cluster.wait_for_deployment_ready', return_value=True
)
mock_install_cert = mocker.patch(
'xpk.commands.cluster.install_cert_manager', return_value=0
)
mock_download = mocker.patch(
'xpk.commands.cluster.download_mldiagnostics_yaml', return_value=0
)
mock_create_ns = mocker.patch(
'xpk.commands.cluster.create_mldiagnostics_namespace', return_value=0
)
mock_install_yaml = mocker.patch(
'xpk.commands.cluster.install_mldiagnostics_yaml', return_value=0
)
mock_label_ns = mocker.patch(
'xpk.commands.cluster.label_default_namespace_mldiagnostics',
return_value=0,
)

mocker.patch('os.path.exists', return_value=True)
mocker.patch('os.remove')

install_mldiagnostics_prerequisites()

mock_wait_ready.assert_any_call('kueue-controller-manager', 'kueue-system')

assert mock_sleep.call_count == 2
mock_sleep.assert_any_call(30)

mock_install_cert.assert_called_once()

mock_wait_ready.assert_any_call('cert-manager-webhook', 'cert-manager')

assert mock_download.call_count == 2
mock_download.assert_any_call(
package_name='mldiagnostics-injection-webhook', version='v0.5.0'
)
mock_download.assert_any_call(
package_name='mldiagnostics-connection-operator', version='v0.5.0'
)

mock_create_ns.assert_called_once()

assert mock_install_yaml.call_count == 2

mock_label_ns.assert_called_once()
Loading
Loading