Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xlml/apis/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def run_queued_resource_test(
tpu_name = tpu.generate_tpu_name(
task_test_config.benchmark_id, tpu_name_env_var
)
ssh_keys = ssh.generate_ssh_keys()
ssh_keys = ssh.obtain_persist_ssh_keys()
output_location = name_format.generate_gcs_folder_location(
task_test_config.gcs_subfolder,
task_test_config.benchmark_id,
Expand Down
25 changes: 24 additions & 1 deletion xlml/utils/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import dataclasses

from airflow.decorators import task
from airflow.models import Variable
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

Expand All @@ -27,6 +28,7 @@ class SshKeys:

private: str
public: str
user: str


@task
Expand All @@ -46,4 +48,25 @@ def generate_ssh_keys() -> SshKeys:
serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH
)

return SshKeys(private=private_key.decode(), public=public_key.decode())
return SshKeys(
private=private_key.decode(),
public=public_key.decode(),
user="ml-auto-solutions",
)


@task
def obtain_persist_ssh_keys() -> SshKeys:
"""Obtain persistent SSH keys and user from Airflow Variables."""
try:
user = Variable.get("os-login-ssh-user")
private_key = Variable.get("os-login-ssh-private-key")
public_key = Variable.get("os-login-ssh-public-key")
except KeyError as e:
raise ValueError(
f"Required Airflow Variable {e} is not set. "
"Please ensure 'os-login-ssh-user', 'os-login-ssh-private-key', "
"and 'os-login-ssh-public-key' are configured."
) from e

return SshKeys(private=private_key, public=public_key, user=user)
14 changes: 2 additions & 12 deletions xlml/utils/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def create_queued_resource_request(
)

metadata = {
'ssh-keys': f'ml-auto-solutions:{ssh_keys.public}',
'ssh-keys': f'{ssh_keys.user}:{ssh_keys.public}',
'startup-script': startup_script_command,
}

Expand Down Expand Up @@ -374,16 +374,6 @@ def ssh_tpu(
client.get_node(name=os.path.join(node.parent, 'nodes', node.node_id))
for node in queued_resource.tpu.node_spec
]
node_metadata = nodes[0].metadata
is_oslogin_enabled = node_metadata.get('enable-oslogin', '') == 'TRUE'

user = 'ml-auto-solutions'
if is_oslogin_enabled:
logging.info('Auto-detected OS Login enabled on node {nodes[0].name}..')
# get private key from Airflow Variable
user = Variable.get('os-login-ssh-user')
ssh_keys.private = Variable.get('os-login-ssh-private-key')
ssh_keys.public = Variable.get('os-login-ssh-public-key')

if all_workers:
endpoints = itertools.chain.from_iterable(
Expand All @@ -407,7 +397,7 @@ def ssh_tpu(
*ip_addresses,
connect_kwargs={
'auth_strategy': paramiko.auth_strategy.InMemoryPrivateKey(
user, pkey
ssh_keys.user, pkey
),
# See https://stackoverflow.com/a/59453832
'banner_timeout': 200,
Expand Down
Loading