Skip to content

Commit 637afae

Browse files
szaher and RobotSail authored
fix(torchrun): Omit empty arguments and correct nproc_per_node type (#661)
* fix(torchrun): Omit empty arguments and correct nproc_per_node type The command generation logic is updated to dynamically build the torchrun command, excluding arguments that are empty or None. This prevents them from overriding environment variables, ensuring that torchrun can correctly inherit its configuration. An exception is made for integer arguments where 0 is a valid value. Additionally, the nproc_per_node argument type has been changed from int to str to support special values accepted by PyTorch, such as 'auto', 'gpu', and 'cpu'. Reference: https://github.com/pytorch/pytorch/blob/main/torch/distributed/run.py#L77-L88 Signed-off-by: Saad Zaher <szaher@redhat.com> * only dynamically add torchrun args & change rdzv_id type to str Signed-off-by: Saad Zaher <szaher@redhat.com> * fix smoke tests Signed-off-by: Saad Zaher <szaher@redhat.com> * Enable both dtypes str, int for nproc_per_node, rdzv_id Signed-off-by: Saad Zaher <szaher@redhat.com> * Use python3.11 style for pydatnic model Signed-off-by: Saad Zaher <szaher@redhat.com> * add all torchrun args and validate them Signed-off-by: Saad Zaher <szaher@redhat.com> * Remove non-required dependencies Signed-off-by: Saad Zaher <szaher@redhat.com> * update datatypes only Signed-off-by: Saad Zaher <szaher@redhat.com> * replace _ with - when passing torchrun args Signed-off-by: Saad Zaher <szaher@redhat.com> * make nproc_per_node to only accept gpu or int Signed-off-by: Saad Zaher <szaher@redhat.com> * add master_{addr, port} validate args Signed-off-by: Saad Zaher <szaher@redhat.com> * check for not set or empty rdzv endpoint Signed-off-by: Saad Zaher <szaher@redhat.com> * fix formatting error Signed-off-by: Saad Zaher <szaher@redhat.com> * Update src/instructlab/training/config.py Signed-off-by: Saad Zaher <szaher@redhat.com> * Update tests/smoke/test_train.py Signed-off-by: Saad Zaher <szaher@redhat.com> * Update src/instructlab/training/main_ds.py Signed-off-by: Saad Zaher <szaher@redhat.com> * fixes 
indentation Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> * formatting * add standalone as the fallback when neither master_addr nor rdzv_endpoint are provided Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> * clarify rdzv-backend arg --------- Signed-off-by: Saad Zaher <szaher@redhat.com> Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Co-authored-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com>
1 parent 2c8d676 commit 637afae

File tree

2 files changed

+67
-23
lines changed

2 files changed

+67
-23
lines changed

src/instructlab/training/config.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import List, Literal, Optional
1010

1111
# Third Party
12-
from pydantic import BaseModel, ConfigDict, Field
12+
from pydantic import BaseModel, ConfigDict, Field, model_validator
1313

1414

1515
# public API
@@ -70,16 +70,33 @@ class DataProcessArgs(BaseModel):
7070
# public API
7171
class TorchrunArgs(BaseModel):
    """
    Arguments for torchrun (https://pytorch.org/docs/stable/elastic/run.html#definitions).

    Precedence order: arg > env > defaults.

    `rdzv_endpoint` and `master_addr` are mutually exclusive: setting both
    raises a ValueError. Setting neither is valid — the command builder then
    falls back to torchrun's `--standalone` mode. `master_port` is optional
    even when `master_addr` is given (torchrun supplies its own default port).
    """

    # Core distributed training arguments.
    # nproc_per_node accepts the literal "gpu" (one worker per GPU) or an
    # explicit worker count, mirroring torchrun's own accepted values.
    nproc_per_node: Literal["gpu"] | int
    nnodes: int
    node_rank: int
    # torchrun accepts arbitrary string job IDs; int kept for backward compat.
    rdzv_id: str | int

    # Rendezvous / master configuration — all optional so that unset values
    # are omitted from the command line and torchrun can inherit them from
    # environment variables instead of being overridden by empty strings.
    rdzv_endpoint: Optional[str] = None
    master_addr: Optional[str] = None
    master_port: Optional[int] = None

    # Ignore unknown fields rather than erroring, so older/newer callers
    # passing extra keys do not break model construction.
    model_config = ConfigDict(extra="ignore")

    @model_validator(mode="after")
    def validate_endpoint_config(self):
        """Reject the ambiguous case where both rendezvous styles are set."""
        if self.rdzv_endpoint and self.master_addr:
            raise ValueError(
                "Provide either `rdzv_endpoint` OR both `master_addr` and `master_port`, not both."
            )
        return self
83100

84101

85102
# public API

src/instructlab/training/main_ds.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -480,27 +480,54 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
480480
if not os.path.exists(train_args.ckpt_output_dir):
481481
os.makedirs(train_args.ckpt_output_dir, exist_ok=True)
482482

483+
# build distributed training command
483484
command = [
484485
"torchrun",
486+
f"--nproc-per-node={torch_args.nproc_per_node}",
485487
f"--nnodes={torch_args.nnodes}",
486-
f"--node_rank={torch_args.node_rank}",
487-
f"--nproc_per_node={torch_args.nproc_per_node}",
488-
f"--rdzv_id={torch_args.rdzv_id}",
489-
f"--rdzv_endpoint={torch_args.rdzv_endpoint}",
490-
__file__,
491-
f"--model_name_or_path={train_args.model_path}",
492-
f"--data_path={train_args.data_output_dir}/data.jsonl",
493-
f"--output_dir={train_args.ckpt_output_dir}",
494-
f"--num_epochs={train_args.num_epochs}",
495-
f"--effective_batch_size={train_args.effective_batch_size}",
496-
f"--learning_rate={train_args.learning_rate}",
497-
f"--num_warmup_steps={train_args.warmup_steps}",
498-
f"--save_samples={train_args.save_samples}",
499-
f"--log_level={train_args.log_level}",
500-
f"--max_batch_len={train_args.max_batch_len}",
501-
f"--seed={train_args.random_seed}",
488+
f"--node-rank={torch_args.node_rank}",
489+
f"--rdzv-id={torch_args.rdzv_id}",
502490
]
503491

492+
# validation should have already caught the mutually exclusive case earlier, but here we check
493+
# anyway just to be extra sure since Python is not type-safe and validation can be bypassed (e.g. during testing)
494+
if torch_args.master_addr and torch_args.rdzv_endpoint:
495+
raise ValueError(
496+
"`torch_args.master_addr` and `torch_args.rdzv_endpoint` cannot be passed at the same time; please pass only one"
497+
)
498+
499+
if torch_args.master_addr:
500+
command += [
501+
f"--master-addr={torch_args.master_addr}",
502+
"--rdzv-backend=static",
503+
]
504+
command += (
505+
[f"--master-port={torch_args.master_port}"]
506+
if torch_args.master_port
507+
else []
508+
)
509+
elif torch_args.rdzv_endpoint:
510+
command += [f"--rdzv-endpoint={torch_args.rdzv_endpoint}"]
511+
else:
512+
command += ["--standalone"]
513+
514+
command.extend(
515+
[
516+
__file__,
517+
f"--model_name_or_path={train_args.model_path}",
518+
f"--data_path={train_args.data_output_dir}/data.jsonl",
519+
f"--output_dir={train_args.ckpt_output_dir}",
520+
f"--num_epochs={train_args.num_epochs}",
521+
f"--effective_batch_size={train_args.effective_batch_size}",
522+
f"--learning_rate={train_args.learning_rate}",
523+
f"--num_warmup_steps={train_args.warmup_steps}",
524+
f"--save_samples={train_args.save_samples}",
525+
f"--log_level={train_args.log_level}",
526+
f"--max_batch_len={train_args.max_batch_len}",
527+
f"--seed={train_args.random_seed}",
528+
]
529+
)
530+
504531
if train_args.chat_tmpl_path is not None:
505532
command.append(f"--chat-tmpl-path={train_args.chat_tmpl_path}")
506533

0 commit comments

Comments
 (0)