diff --git a/xlml/apis/task.py b/xlml/apis/task.py index fb4453d74..b60a71e01 100644 --- a/xlml/apis/task.py +++ b/xlml/apis/task.py @@ -102,7 +102,7 @@ def run_queued_resource_test( tpu_name = tpu.generate_tpu_name( task_test_config.benchmark_id, tpu_name_env_var ) - ssh_keys = ssh.generate_ssh_keys() + ssh_keys = ssh.obtain_persist_ssh_keys() output_location = name_format.generate_gcs_folder_location( task_test_config.gcs_subfolder, task_test_config.benchmark_id, diff --git a/xlml/utils/ssh.py b/xlml/utils/ssh.py index 6d13fc012..a5b52fb24 100644 --- a/xlml/utils/ssh.py +++ b/xlml/utils/ssh.py @@ -17,6 +17,7 @@ import dataclasses from airflow.decorators import task +from airflow.models import Variable from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa @@ -27,6 +28,7 @@ class SshKeys: private: str public: str + user: str @task @@ -46,4 +48,25 @@ def generate_ssh_keys() -> SshKeys: serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH ) - return SshKeys(private=private_key.decode(), public=public_key.decode()) + return SshKeys( + private=private_key.decode(), + public=public_key.decode(), + user="ml-auto-solutions", + ) + + +@task +def obtain_persist_ssh_keys() -> SshKeys: + """Obtain persistent SSH keys and user from Airflow Variables.""" + try: + user = Variable.get("os-login-ssh-user") + private_key = Variable.get("os-login-ssh-private-key") + public_key = Variable.get("os-login-ssh-public-key") + except KeyError as e: + raise ValueError( + f"Required Airflow Variable {e} is not set. " + "Please ensure 'os-login-ssh-user', 'os-login-ssh-private-key', " + "and 'os-login-ssh-public-key' are configured." + ) from e + + return SshKeys(private=private_key, public=public_key, user=user) diff --git a/xlml/utils/tpu.py b/xlml/utils/tpu.py index a1501fe6c..c08648e71 100644 --- a/xlml/utils/tpu.py +++ b/xlml/utils/tpu.py @@ -128,7 +128,7 @@ def create_queued_resource_request( ) metadata = { - 'ssh-keys': f'ml-auto-solutions:{ssh_keys.public}', + 'ssh-keys': f'{ssh_keys.user}:{ssh_keys.public}', 'startup-script': startup_script_command, } @@ -374,16 +374,6 @@ def ssh_tpu( client.get_node(name=os.path.join(node.parent, 'nodes', node.node_id)) for node in queued_resource.tpu.node_spec ] - node_metadata = nodes[0].metadata - is_oslogin_enabled = node_metadata.get('enable-oslogin', '') == 'TRUE' - - user = 'ml-auto-solutions' - if is_oslogin_enabled: - logging.info('Auto-detected OS Login enabled on node {nodes[0].name}..') - # get private key from Airflow Variable - user = Variable.get('os-login-ssh-user') - ssh_keys.private = Variable.get('os-login-ssh-private-key') - ssh_keys.public = Variable.get('os-login-ssh-public-key') if all_workers: endpoints = itertools.chain.from_iterable( @@ -407,7 +397,7 @@ def ssh_tpu( *ip_addresses, connect_kwargs={ 'auth_strategy': paramiko.auth_strategy.InMemoryPrivateKey( - user, pkey + ssh_keys.user, pkey ), # See https://stackoverflow.com/a/59453832 'banner_timeout': 200,