Skip to content

Commit 5a03d26

Browse files
authored
[cli] support run as module option (#6135)
1 parent cc40fe0 commit 5a03d26

File tree

2 files changed

+30 -6 lines changed (30 additions, 6 deletions)

2 files changed

+30 -6 lines changed (30 additions, 6 deletions)

colossalai/cli/launcher/__init__.py

Lines changed: 22 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,8 @@
6464
"This will be converted to --arg1=1 --arg2=2 during execution",
6565
)
6666
@click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection")
67-
@click.argument("user_script", type=str)
67+
@click.option("-m", type=str, default=None, help="run library module as a script (terminates option list)")
68+
@click.argument("user_script", type=str, required=False, default=None)
6869
@click.argument("user_args", nargs=-1)
6970
def run(
7071
host: str,
@@ -77,8 +78,9 @@ def run(
7778
master_port: int,
7879
extra_launch_args: str,
7980
ssh_port: int,
81+
m: str,
8082
user_script: str,
81-
user_args: str,
83+
user_args: tuple,
8284
) -> None:
8385
"""
8486
To launch multiple processes on a single node or multiple nodes via command line.
@@ -102,9 +104,24 @@ def run(
102104
# run with hostfile excluding the hosts selected
103105
colossalai run --hostfile <file_path> --master_addr host1 --exclude host2 --nprocs_per_node 4 train.py
104106
"""
105-
if not user_script.endswith(".py"):
106-
click.echo(f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help")
107-
exit()
107+
if m is not None:
108+
if m.endswith(".py"):
109+
click.echo(f"Error: invalid Python module {m}. Did you use a wrong option? Try colossalai run --help")
110+
exit()
111+
if user_script is not None:
112+
user_args = (user_script,) + user_args
113+
user_script = m
114+
m = True
115+
else:
116+
if user_script is None:
117+
click.echo("Error: missing script argument. Did you use a wrong option? Try colossalai run --help")
118+
exit()
119+
if not user_script.endswith(".py"):
120+
click.echo(
121+
f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help"
122+
)
123+
exit()
124+
m = False
108125

109126
args_dict = locals()
110127
args = Config(args_dict)

colossalai/cli/launcher/run.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -113,6 +113,7 @@ def get_launch_command(
113113
user_args: List[str],
114114
node_rank: int,
115115
num_nodes: int,
116+
run_as_module: bool,
116117
extra_launch_args: str = None,
117118
) -> str:
118119
"""
@@ -155,6 +156,8 @@ def _arg_dict_to_list(arg_dict):
155156

156157
torch_version = version.parse(torch.__version__)
157158
assert torch_version.major >= 1
159+
if torch_version.major < 2 and run_as_module:
160+
raise ValueError("Torch version < 2.0 does not support running as module")
158161

159162
if torch_version.major == 1 and torch_version.minor < 9:
160163
# torch distributed launch cmd with torch < 1.9
@@ -198,7 +201,10 @@ def _arg_dict_to_list(arg_dict):
198201
]
199202
cmd += _arg_dict_to_list(default_torchrun_rdzv_args)
200203

201-
cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args
204+
cmd += _arg_dict_to_list(extra_launch_args)
205+
if run_as_module:
206+
cmd.append("-m")
207+
cmd += [user_script] + user_args
202208
cmd = " ".join(cmd)
203209
return cmd
204210

@@ -294,6 +300,7 @@ def launch_multi_processes(args: Config) -> None:
294300
user_args=args.user_args,
295301
node_rank=node_id,
296302
num_nodes=len(active_device_pool),
303+
run_as_module=args.m,
297304
extra_launch_args=args.extra_launch_args,
298305
)
299306
runner.send(hostinfo=hostinfo, cmd=cmd)

0 commit comments

Comments
 (0)