
Commit bf79697

working, can now launch from cli
1 parent fd3ddcd commit bf79697

File tree

2 files changed: 41 additions & 23 deletions

dist_run.py

Lines changed: 10 additions & 6 deletions
@@ -20,14 +20,14 @@
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
 from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
 
-from torchchat.distributed.logging_utils import SingletonLogger
-
 # TODO - these are not distributed specific, consider moving to new package
 from torchchat.distributed.checkpoint_utils import (
     get_hf_config_file,
     load_weights_from_hf_format,
     load_weights_from_torchchat_format,
 )
+
+from torchchat.distributed.logging_utils import SingletonLogger
 from torchchat.distributed.utils import (
     bytes_to_readable,
     Color as color,

@@ -153,7 +153,9 @@ def _load_model_weights(
         # This format stands for:
         # single binary file, OR
         # multiple binary files without index files.
-        load_weights_from_torchchat_format(stage_module, distribution, device, model_config)
+        load_weights_from_torchchat_format(
+            stage_module, distribution, device, model_config
+        )
     else:
         raise ValueError(f"Unknown checkpoint format: {chpt_from}")

@@ -304,7 +306,7 @@ def _cleanup():
 
 
 def main(args):
-    model_name = args.model_name
+    model_name = "llama3" # args.model_name
     pp_degree = args.pp
 
     rank, world_size = _init_distributed()

@@ -590,12 +592,14 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
+    """parser.add_argument(
        "model_name",
        type=str,
+        default="llama3",
        help="Name of the model to load",
-        choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
+        # choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
    )
+    """
    parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
    parser.add_argument(
        "--ntokens",

torchchat/cli/builder.py

Lines changed: 31 additions & 17 deletions
@@ -510,23 +510,29 @@ def _load_model(builder_args: BuilderArgs) -> Model:
     return model.eval()
 
 
-@record
-def run_main(local_rank):
-    # Add the directory containing the train file to sys.path
-    train_file_path = Path(__file__).parent.parent.parent / "dist_run.py"
-    print(f"******* {train_file_path=}")
-    sys.path.insert(0, os.path.dirname(os.path.abspath(train_file_path)))
+import importlib.util
+import subprocess
+
 
-    # Set environment variables for distributed training
-    os.environ["LOCAL_RANK"] = str(local_rank)
-    os.environ["RANK"] = str(
-        local_rank # + kwargs.get("node_rank", 0) * num_processes_per_node
+def run_script(script_path, *args):
+    # Construct the command to run the script
+    cmd = [sys.executable, script_path] + list(args)
+
+    # Run the script as a subprocess
+    process = subprocess.Popen(
+        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True
     )
-    os.environ["WORLD_SIZE"] = str(4 * 1) # num_nodes)
 
-    # Execute the train file
-    with open(train_file_path, "rb") as file:
-        exec(compile(file.read(), train_file_path, "exec"))
+    # Stream the output in real-time
+    for line in process.stdout:
+        print(line, end="")
+    for line in process.stderr:
+        print(line, end="", file=sys.stderr)
+
+    # Wait for the process to complete and get the return code
+    return_code = process.wait()
+    if return_code != 0:
+        raise subprocess.CalledProcessError(return_code, cmd)
 
 
 def _launch_distributed_inference(builder_args: BuilderArgs) -> None:
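
One caveat in the new `run_script`: it drains `process.stdout` to EOF before reading `process.stderr`, so a child that fills the stderr pipe buffer while the parent is still blocked on stdout will stall both processes. A minimal variant that avoids this (same signature; it merges stderr into stdout rather than reading the pipes in sequence):

    import subprocess
    import sys

    def run_script(script_path, *args):
        cmd = [sys.executable, script_path] + list(args)
        # stderr=STDOUT folds both streams into one pipe, so neither
        # can fill up while the other is being drained.
        with subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
        ) as process:
            for line in process.stdout:
                print(line, end="")
        # The context manager has already waited, so returncode is set.
        if process.returncode != 0:
            raise subprocess.CalledProcessError(process.returncode, cmd)

The trade-off is losing the stdout/stderr distinction in the parent; keeping the streams separate safely would need threads or `selectors`.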
@@ -546,12 +552,20 @@ def _launch_distributed_inference(builder_args: BuilderArgs) -> None:
         monitor_interval=1,
     )
 
-    train_file_path = Path(__file__).parent / "distributed" / "dist_run.py"
+    train_file_path = Path(__file__).parent.parent.parent / "dist_run.py"
+    print(f"train_file_path: {train_file_path}")
+    # import argparse
+
+    # parser2 = argparse.ArgumentParser()
+
+    # args = parser2.parse_args()
+    args = []
+    print(f"args: {args}")
 
     elastic_launch(
         config=lc,
-        entrypoint=run_main,
-    )(train_file_path)
+        entrypoint=run_script,
+    )(train_file_path, *args)
     print(
         f"Done launching distributed inference on **4 ** {builder_args.num_gpus} GPUs."
     )
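
For context on the `entrypoint` change: `elastic_launch(config=lc, entrypoint=fn)(*args)` spawns `nproc_per_node` local workers and calls `fn(*args)` in each, with `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` already exported into each worker's environment; that is why the hand-set `os.environ[...]` lines in the deleted `run_main` are no longer needed. A minimal self-contained sketch (the config values here are illustrative, not the ones `_launch_distributed_inference` actually builds):

    import os

    from torch.distributed.launcher.api import LaunchConfig, elastic_launch

    def worker(tag: str) -> None:
        # torchelastic exports these for every spawned process.
        print(tag, os.environ["RANK"], os.environ["LOCAL_RANK"], os.environ["WORLD_SIZE"])

    if __name__ == "__main__":
        lc = LaunchConfig(
            min_nodes=1,
            max_nodes=1,
            nproc_per_node=4,  # illustrative; builder.py sizes this from num_gpus
            rdzv_backend="c10d",
            rdzv_endpoint="localhost:29500",
            monitor_interval=1,
        )
        elastic_launch(config=lc, entrypoint=worker)("hello")

In the commit, each per-rank invocation of `run_script` then starts a fresh `python dist_run.py` subprocess, which inherits those environment variables, so the script can initialize distributed state without any hand-wired ranks.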
