Skip to content

Commit ba360e1

Browse files
authored
fix: Unify TPU SSH mechanism to resolve race conditions (#1182)
This change transitions the TPU SSH connection mechanism from ephemeral key injection to a persistent OS Login architecture. By leveraging long-lived SSH keys stored in the Service Account's OS Login profile, we eliminate the race conditions (409 Conflict) frequently encountered when running multiple concurrent TPU tasks in Airflow.
1 parent e5e5fba commit ba360e1

File tree

3 files changed

+27
-14
lines changed

3 files changed

+27
-14
lines changed

xlml/apis/task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def run_queued_resource_test(
102102
tpu_name = tpu.generate_tpu_name(
103103
task_test_config.benchmark_id, tpu_name_env_var
104104
)
105-
ssh_keys = ssh.generate_ssh_keys()
105+
ssh_keys = ssh.obtain_persist_ssh_keys()
106106
output_location = name_format.generate_gcs_folder_location(
107107
task_test_config.gcs_subfolder,
108108
task_test_config.benchmark_id,

xlml/utils/ssh.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import dataclasses
1818

1919
from airflow.decorators import task
20+
from airflow.models import Variable
2021
from cryptography.hazmat.primitives import serialization
2122
from cryptography.hazmat.primitives.asymmetric import rsa
2223

@@ -27,6 +28,7 @@ class SshKeys:
2728

2829
private: str
2930
public: str
31+
user: str
3032

3133

3234
@task
@@ -46,4 +48,25 @@ def generate_ssh_keys() -> SshKeys:
4648
serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH
4749
)
4850

49-
return SshKeys(private=private_key.decode(), public=public_key.decode())
51+
return SshKeys(
52+
private=private_key.decode(),
53+
public=public_key.decode(),
54+
user="ml-auto-solutions",
55+
)
56+
57+
58+
@task
59+
def obtain_persist_ssh_keys() -> SshKeys:
60+
"""Obtain persistent SSH keys and user from Airflow Variables."""
61+
try:
62+
user = Variable.get("os-login-ssh-user")
63+
private_key = Variable.get("os-login-ssh-private-key")
64+
public_key = Variable.get("os-login-ssh-public-key")
65+
except KeyError as e:
66+
raise ValueError(
67+
f"Required Airflow Variable {e} is not set. "
68+
"Please ensure 'os-login-ssh-user', 'os-login-ssh-private-key', "
69+
"and 'os-login-ssh-public-key' are configured."
70+
) from e
71+
72+
return SshKeys(private=private_key, public=public_key, user=user)

xlml/utils/tpu.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def create_queued_resource_request(
128128
)
129129

130130
metadata = {
131-
'ssh-keys': f'ml-auto-solutions:{ssh_keys.public}',
131+
'ssh-keys': f'{ssh_keys.user}:{ssh_keys.public}',
132132
'startup-script': startup_script_command,
133133
}
134134

@@ -374,16 +374,6 @@ def ssh_tpu(
374374
client.get_node(name=os.path.join(node.parent, 'nodes', node.node_id))
375375
for node in queued_resource.tpu.node_spec
376376
]
377-
node_metadata = nodes[0].metadata
378-
is_oslogin_enabled = node_metadata.get('enable-oslogin', '') == 'TRUE'
379-
380-
user = 'ml-auto-solutions'
381-
if is_oslogin_enabled:
382-
logging.info('Auto-detected OS Login enabled on node {nodes[0].name}..')
383-
# get private key from Airflow Variable
384-
user = Variable.get('os-login-ssh-user')
385-
ssh_keys.private = Variable.get('os-login-ssh-private-key')
386-
ssh_keys.public = Variable.get('os-login-ssh-public-key')
387377

388378
if all_workers:
389379
endpoints = itertools.chain.from_iterable(
@@ -407,7 +397,7 @@ def ssh_tpu(
407397
*ip_addresses,
408398
connect_kwargs={
409399
'auth_strategy': paramiko.auth_strategy.InMemoryPrivateKey(
410-
user, pkey
400+
ssh_keys.user, pkey
411401
),
412402
# See https://stackoverflow.com/a/59453832
413403
'banner_timeout': 200,

0 commit comments

Comments
 (0)