Skip to content

Commit b0a045f

Browse files
authored
Make distributed mode automatic (#24)
* Make distributed mode automatic * Update help text for --workers argument
1 parent f41b9b2 commit b0a045f

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

csub.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@ def build_parser() -> argparse.ArgumentParser:
4848
parser.add_argument("-i", "--image", type=str, help="Override RUNAI_IMAGE from the env file")
4949
parser.add_argument("-p", "--port", type=int, help="Expose a container port")
5050
parser.add_argument("--train", action="store_true", help="Submit as a training workload")
51-
parser.add_argument("--distributed", action="store_true", help="Submit a distributed workload")
52-
parser.add_argument("--workers", default=0, type=int, help="Only read for distributed workloads. Number of nodes IN ADDITION to the master node. I.e., the total number of nodes is the number of workers + 1 (the master node)")
51+
parser.add_argument("--workers", default=0, type=int, help="Number of nodes IN ADDITION to the master node. I.e., the total number of nodes is the number of workers + 1 (the master node)")
5352
parser.add_argument("--dry", action="store_true", help="Print the generated runai command")
5453
parser.add_argument("--env-file", type=str, default=DEFAULT_ENV_FILE, help="Path to the .env file (default: .env in the repo root)")
5554
parser.add_argument("--sync-secret-only", action="store_true", help="Create/refresh the Kubernetes secret and exit without submitting a job")
@@ -66,8 +65,10 @@ def build_parser() -> argparse.ArgumentParser:
6665
def build_runai_command(
6766
args: argparse.Namespace, env: Dict[str, str]
6867
) -> Tuple[List[str], str]:
69-
assert args.train + args.distributed <= 1, "Choose --train or --distributed but not both"
70-
68+
distributed = args.workers > 0
69+
if not args.train and distributed:
70+
args.train |= distributed
71+
print("Forcing non-interactive as distributed")
7172
job_name = (
7273
args.name
7374
or f"{env['LDAP_USERNAME']}-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
@@ -129,7 +130,7 @@ def build_runai_command(
129130
shell_command += f" && {user_command}"
130131

131132
cmd: List[str] = ["runai"]
132-
cmd.extend(["submit-dist", "pytorch"] if args.distributed else ["submit"])
133+
cmd.extend(["submit-dist", "pytorch"] if distributed else ["submit"])
133134
cmd.extend([
134135
"--name",
135136
job_name,
@@ -157,7 +158,7 @@ def build_runai_command(
157158
if args.memory:
158159
cmd.extend(["--memory", args.memory])
159160

160-
if not args.train and not args.distributed:
161+
if not args.train:
161162
cmd.append("--interactive")
162163
else:
163164
cmd.extend(["--backoff-limit", str(args.backofflimit)])
@@ -172,10 +173,10 @@ def build_runai_command(
172173

173174
if args.node_type:
174175
cmd.extend(["--node-pools", args.node_type])
175-
if args.node_type in {"h200", "h100"} and not args.train and not args.distributed:
176+
if args.node_type in {"h200", "h100"} and not args.train:
176177
cmd.append("--preemptible")
177178

178-
if args.distributed:
179+
if distributed:
179180
cmd.extend([
180181
"--workers", str(args.workers),
181182
"--annotation", "k8s.v1.cni.cncf.io/networks=kube-system/roce",

0 commit comments

Comments
 (0)