
Commit fd3ddcd

add elastic_launch
1 parent 2f1787c commit fd3ddcd

File tree: 2 files changed (83 additions, 15 deletions)


torchchat/cli/builder.py

Lines changed: 70 additions & 10 deletions
@@ -16,7 +16,12 @@
 import torch._inductor.config
 import torch.nn as nn
 
+from torch.distributed import launcher
+
 from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.elastic.utils.distributed import get_free_port
+from torch.distributed.launcher.api import elastic_launch
 
 from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
 
@@ -58,8 +63,8 @@ class BuilderArgs:
     distributed: bool = False
     num_gpus: int = 1
     num_nodes: int = 1
-    pp_dim: int = 1
-    tp_dim: int = 1
+    pp: int = 1
+    tp: int = 1
     is_chat_model: bool = False
     prefill_possible: bool = False
     dynamic_shapes: bool = False
@@ -164,8 +169,8 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         distributed = getattr(args, "distributed", False)
         num_gpus = getattr(args, "num_gpus", 1)
         num_nodes = getattr(args, "num_nodes", 1)
-        pp_dim = getattr(args, "pp_dim", 1)
-        tp_dim = getattr(args, "tp_dim", 1)
+        pp = getattr(args, "pp", 1)
+        tp = getattr(args, "tp", 1)
         return cls(
             checkpoint_dir=checkpoint_dir,
             checkpoint_path=checkpoint_path,
@@ -182,8 +187,8 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             distributed=distributed,
             num_gpus=num_gpus,
             num_nodes=num_nodes,
-            pp_dim=pp_dim,
-            tp_dim=tp_dim,
+            pp=pp,
+            tp=tp,
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
@@ -492,19 +497,70 @@ def _maybe_parellelize_model(
 
 
 def _load_model(builder_args: BuilderArgs) -> Model:
-    world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
+    # world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
     if builder_args.gguf_path:
         model = _load_model_gguf(builder_args)
-    elif builder_args.use_distributed:
-        model = _init_model_on_meta_device(builder_args)
+    # elif builder_args.use_distributed:
+    #     model = _init_model_on_meta_device(builder_args)
     else:
         model = _load_model_default(builder_args)
-    model = _maybe_parellelize_model(model, builder_args, world_mesh, parallel_dims)
+    # model = _maybe_parellelize_model(model, builder_args, world_mesh, parallel_dims)
 
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
 
 
+@record
+def run_main(local_rank):
+    # Add the directory containing the train file to sys.path
+    train_file_path = Path(__file__).parent.parent.parent / "dist_run.py"
+    print(f"******* {train_file_path=}")
+    sys.path.insert(0, os.path.dirname(os.path.abspath(train_file_path)))
+
+    # Set environment variables for distributed training
+    os.environ["LOCAL_RANK"] = str(local_rank)
+    os.environ["RANK"] = str(
+        local_rank  # + kwargs.get("node_rank", 0) * num_processes_per_node
+    )
+    os.environ["WORLD_SIZE"] = str(4 * 1)  # num_nodes
+
+    # Execute the train file
+    with open(train_file_path, "rb") as file:
+        exec(compile(file.read(), train_file_path, "exec"))
+
+
+def _launch_distributed_inference(builder_args: BuilderArgs) -> None:
+    # create programmatic elastic launch
+    print("Launching distributed inference ...")
+
+    num_processes_per_node = 4  # builder_args.num_gpus + 1
+
+    lc = launcher.LaunchConfig(
+        min_nodes=1,
+        max_nodes=1,
+        nproc_per_node=num_processes_per_node,
+        # run_id=str(uuid.uuid4()),
+        rdzv_backend="c10d",
+        rdzv_endpoint="localhost:29401",
+        max_restarts=0,
+        monitor_interval=1,
+    )
+
+    train_file_path = Path(__file__).parent / "distributed" / "dist_run.py"
+
+    elastic_launch(
+        config=lc,
+        entrypoint=run_main,
+    )(train_file_path)
+    print(
+        f"Done launching distributed inference on **4 ** {builder_args.num_gpus} GPUs."
+    )
+    # role=role, *args, **kwargs)
+
+    # assert False, "distributed inference is not supported yet"
+    # pass
+
+
 def _initialize_model(
     builder_args: BuilderArgs,
     quantize,
@@ -513,6 +569,10 @@ def _initialize_model(
     support_tensor_subclass: bool = True,
 ) -> Model:
     print("Loading model...")
+    if builder_args.distributed:
+        # we part ways here with torchchat cli and move into dist inference
+        _launch_distributed_inference(builder_args)
+        return None
 
     if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
         print("Setting gguf_kwargs for generate.")

torchchat/generate.py

Lines changed: 13 additions & 5 deletions
@@ -245,11 +245,7 @@ def __init__(
             f"Using device={self.builder_args.device} {get_device_info(self.builder_args.device)}"
         )
         set_precision(self.builder_args.precision)
-        if builder_args.distributed:
-            print(f"Using distributed={builder_args.distributed}")
-            device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
-            torch.cuda.set_device(device)
-            assert False, "Distributed is not supported yet"
+
         self.is_speculative = self.speculative_builder_args.checkpoint_path is not None
 
         if generator_args.chat_mode and not self.builder_args.is_chat_model:
@@ -1205,6 +1201,15 @@ def callback(x, *, done_generating=False):
         print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
 
 
+def _launch_distributed_inference(
+    builder_args: BuilderArgs,
+):
+    from torch.distributed import launcher
+    from torch.distributed.elastic.utils.distributed import get_free_port
+
+    print("Launching distributed inference within generator")
+
+
 def main(args):
     builder_args = BuilderArgs.from_args(args)
     speculative_builder_args = BuilderArgs.from_speculative_args(args)
@@ -1221,5 +1226,8 @@ def main(args):
     )
     if torch.cuda.is_available():
         torch.cuda.reset_peak_memory_stats()
+    if builder_args.distributed:
+
+        return
     for _ in gen.chat(generator_args):
         pass
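
The `_launch_distributed_inference` stub added to generate.py imports `get_free_port` but does not yet call it. A plausible use, which is an assumption on my part and not confirmed by this commit, is to derive the rendezvous endpoint dynamically instead of hard-coding `localhost:29401` as builder.py does:

# Assumed usage sketch: build the c10d rendezvous endpoint from a free
# port instead of the hard-coded localhost:29401 used in builder.py.
from torch.distributed.elastic.utils.distributed import get_free_port

port = get_free_port()  # asks the OS for a currently unused TCP port
rdzv_endpoint = f"localhost:{port}"
print(f"rendezvous endpoint: {rdzv_endpoint}")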
