Skip to content

Commit 27de9f3

Browse files
committed
Added --files_to_copy support for remote_run
1 parent cbcaffa commit 27de9f3

File tree

4 files changed

+143
-11
lines changed

4 files changed

+143
-11
lines changed

automation/script/module.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ def __init__(self, action_object, automation_file):
6969
'MLC_RENEW_CACHE_ENTRY']
7070

7171
self.host_env_keys = [
72+
"USER",
73+
"HOME",
7274
"GH_TOKEN",
7375
"ftp_proxy",
7476
"FTP_PROXY",

automation/script/remote_run.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,14 @@ def remote_run(self_module, i):
5454
'alias', ''), meta.get(
5555
'uid', '')
5656

57+
#Update meta for selected variation and input
58+
r = update_meta_for_selected_variations(self_module, script, i)
59+
if r['return'] > 0:
60+
return r
61+
62+
remote_run_settings = r['remote_run_settings']
63+
env = r['env']
64+
5765
# Execute the experiment script
5866
mlc_script_input = {
5967
'action': 'run', 'target': 'script'
@@ -65,6 +73,12 @@ def remote_run(self_module, i):
6573
if i.get('remote_pull_mlc_repos', False):
6674
run_cmds.append("mlc pull repo")
6775

76+
files_to_copy = []
77+
env_keys_to_copy = remote_run_settings.get('env_keys_to_copy')
78+
for key in env_keys_to_copy:
79+
if key in env and os.path.exists(env[key]):
80+
files_to_copy.append(env[key])
81+
6882
script_run_cmd = " ".join(mlc_run_cmd.split(" ")[1:])
6983
run_cmds.append(f"mlcr {script_run_cmd}")
7084

@@ -73,6 +87,10 @@ def remote_run(self_module, i):
7387
"password", "skip_host_verify", "ssh_key_file"]:
7488
if i.get(f"remote_{key}"):
7589
remote_inputs[key] = i[f"remote_{key}"]
90+
91+
if files_to_copy:
92+
remote_inputs['files_to_copy'] = files_to_copy
93+
7694

7795
# Execute the remote command
7896
mlc_remote_input = {
@@ -86,3 +104,68 @@ def remote_run(self_module, i):
86104
return r
87105

88106
return {'return': 0}
107+
108+
def update_meta_for_selected_variations(self_module, script, input_params):
109+
metadata = script.meta
110+
script_directory = script.path
111+
script_tags = metadata.get("tags", [])
112+
script_alias = metadata.get('alias', '')
113+
script_uid = metadata.get('uid', '')
114+
tag_values = input_params.get('tags', '').split(",")
115+
variation_tags = [tag[1:] for tag in tag_values if tag.startswith("_")]
116+
117+
run_state = {
118+
'deps': [],
119+
'fake_deps': [],
120+
'parent': None,
121+
'script_id': f"{script_alias},{script_uid}",
122+
'script_variation_tags': variation_tags
123+
}
124+
state_data = {}
125+
env = input_params.get('env', {})
126+
constant_vars = input_params.get('const', {})
127+
constant_state = input_params.get('const_state', {})
128+
129+
remote_run_settings = metadata.get('remote_run', {})
130+
remote_run_settings_default_env = remote_run_settings.get('default_env', {})
131+
for key in remote_run_settings_default_env:
132+
env.setdefault(key, remote_run_settings_default_env[key])
133+
134+
state_data['remote_run'] = remote_run_settings
135+
add_deps_recursive = input_params.get('add_deps_recursive', {})
136+
137+
# Update state with metadata and variations
138+
update_state_result = self_module.update_state_from_meta(
139+
metadata, env, state_data, constant_vars, constant_state,
140+
deps=[],
141+
post_deps=[],
142+
prehook_deps=[],
143+
posthook_deps=[],
144+
new_env_keys=[],
145+
new_state_keys=[],
146+
run_state=run_state,
147+
i=input_params
148+
)
149+
if update_state_result['return'] > 0:
150+
return update_state_result
151+
152+
update_variations_result = self_module._update_state_from_variations(
153+
input_params, metadata, variation_tags, metadata.get(
154+
'variations', {}),
155+
env, state_data, constant_vars, constant_state,
156+
deps=[], # Add your dependencies if needed
157+
post_deps=[], # Add post dependencies if needed
158+
prehook_deps=[], # Add prehook dependencies if needed
159+
posthook_deps=[], # Add posthook dependencies if needed
160+
new_env_keys_from_meta=[], # Add keys from meta if needed
161+
new_state_keys_from_meta=[], # Add state keys from meta if needed
162+
add_deps_recursive=add_deps_recursive,
163+
run_state=run_state,
164+
recursion_spaces=''
165+
)
166+
if update_variations_result['return'] > 0:
167+
return update_variations_result
168+
169+
# Set Docker-specific configurations
170+
remote_run_settings = state_data['remote_run']
171+
return {'return': 0, 'remote_run_settings': remote_run_settings, 'env': env, 'state': state_data}

script/remote-run-commands/customize.py

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from mlc import utils
22
import os
3-
3+
import subprocess
44

55
def preprocess(i):
66

@@ -13,6 +13,8 @@ def preprocess(i):
1313
# pre_run_cmds = env.get('MLC_SSH_PRE_RUN_CMDS', ['source $HOME/cm/bin/activate'])
1414
pre_run_cmds = env.get('MLC_SSH_PRE_RUN_CMDS', [])
1515

16+
files_to_copy = env.get('MLC_SSH_FILES_TO_COPY', [])
17+
1618
run_cmds = env.get('MLC_SSH_RUN_COMMANDS', [])
1719

1820
run_cmds = pre_run_cmds + run_cmds
@@ -24,23 +26,68 @@ def preprocess(i):
2426
run_cmds[i] = cmd
2527

2628
cmd_string += " ; ".join(run_cmds)
27-
user = env.get('MLC_SSH_USER')
29+
user = env.get('MLC_SSH_USER', os.environ.get('USER'))
2830
password = env.get('MLC_SSH_PASSWORD', None)
2931
host = env.get('MLC_SSH_HOST')
3032
if password:
3133
password_string = " -p " + password
3234
else:
3335
password_string = ""
34-
cmd_extra = ''
36+
37+
38+
ssh_cmd = ["ssh"]
3539

3640
if env.get("MLC_SSH_SKIP_HOST_VERIFY"):
37-
cmd_extra += " -o StrictHostKeyChecking=no"
38-
if env.get("MLC_SSH_KEY_FILE"):
39-
cmd_extra += " -i " + env.get("MLC_SSH_KEY_FILE")
41+
ssh_cmd += ["-o", "StrictHostKeyChecking=no", "-o", "UserKnownHostsFile=/dev/null"]
42+
43+
key_file = env.get("MLC_SSH_KEY_FILE")
44+
if key_file:
45+
ssh_cmd += ["-i", key_file]
46+
47+
ssh_cmd_str = " ".join(ssh_cmd)
48+
49+
50+
51+
ssh_run_command = ssh_cmd_str + " " +user + "@" + host + \
52+
password_string + " '" + cmd_string + "'"
53+
env['MLC_SSH_CMD'] = ssh_run_command
54+
55+
56+
# ---- Use sshpass if password is provided ----
57+
rsync_base = ["rsync", "-avz"]
58+
59+
if password:
60+
rsync_base = ["sshpass", "-p", password] + rsync_base
61+
62+
# ---- Execute copy commands ----
63+
for file in files_to_copy:
64+
cmd = [
65+
"rsync",
66+
"-avz",
67+
"-e", " ".join(ssh_cmd), # rsync expects a single string here
68+
file,
69+
f"{user}@{host}:",
70+
]
71+
72+
print("Executing:", " ".join(cmd))
73+
74+
result = subprocess.run(
75+
cmd,
76+
env=os.environ,
77+
stdout=subprocess.PIPE,
78+
stderr=subprocess.PIPE,
79+
text=True,
80+
)
81+
82+
if result.returncode != 0:
83+
raise RuntimeError(
84+
f"❌ rsync failed for {file}\n"
85+
f"STDOUT:\n{result.stdout}\n"
86+
f"STDERR:\n{result.stderr}"
87+
)
88+
89+
print(f"✅ Copied {file} successfully")
4090

41-
ssh_command = "ssh " + user + "@" + host + \
42-
password_string + cmd_extra + " '" + cmd_string + "'"
43-
env['MLC_SSH_CMD'] = ssh_command
4491

4592
return {'return': 0}
4693

script/remote-run-commands/meta.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@ category: Remote automation
55
default_env:
66
MLC_SSH_CLIENT_REFRESH: '10'
77
MLC_SSH_HOST: localhost
8-
MLC_SSH_KEY_FILE: $HOME/.ssh/id_rsa
8+
MLC_SSH_KEY_FILE: <<<HOME>>>/.ssh/id_rsa
99
MLC_SSH_PORT: '22'
10-
MLC_SSH_USER: $USER
1110
input_mapping:
1211
client_refresh: MLC_SSH_CLIENT_REFRESH
1312
host: MLC_SSH_HOST
1413
password: MLC_SSH_PASSWORD
1514
port: MLC_SSH_PORT
1615
run_cmds: MLC_SSH_RUN_COMMANDS
16+
files_to_copy: MLC_SSH_FILES_TO_COPY
1717
skip_host_verify: MLC_SSH_SKIP_HOST_VERIFY
1818
ssh_key_file: MLC_SSH_KEY_FILE
1919
user: MLC_SSH_USER

0 commit comments

Comments
 (0)